import hashlib
import json
import argparse
from typing import List

def hash_data(data: str) -> str:
    """Return the SHA-256 digest of `data` (encoded as UTF-8) as a hex string."""
    hasher = hashlib.sha256()
    hasher.update(data.encode("utf-8"))
    return hasher.hexdigest()

def verify_merkle_proof(index: int, leaf_value: str, proof: List[str], root: str, total_leaves: int) -> bool:
    """
    Verifies a Merkle proof for a leaf at a given index.

    The leaf value is hashed with `hash_data`, then folded with each sibling
    hash in `proof` (ordered from the leaf level upwards). At every level the
    parity of the current position decides the concatenation order: an even
    position puts the running hash first, an odd position puts the sibling
    first. The position is halved after each level.

    Args:
        index: Zero-based position of the leaf among the leaves.
        leaf_value: Original (unhashed) value of the leaf.
        proof: Sibling hashes, leaf level first.
        root: Expected root hash.
        total_leaves: Number of leaves in the tree; `index` must lie in
            `range(total_leaves)`.

    Returns:
        True if the recomputed hash equals `root`; False otherwise, including
        when `index` is outside the leaf range.

    NOTE(review): the length of `proof` is not checked against the tree
    depth, and leaves and interior nodes are hashed the same way; callers
    accepting proofs from untrusted parties may want stricter checks.
    """
    # An index outside the leaf range cannot identify a leaf, so no proof for
    # it can be valid. Without this check a negative or too-large index would
    # simply be reduced to a parity pattern and could match another node.
    if not 0 <= index < total_leaves:
        return False

    current_hash = hash_data(leaf_value)
    position = index
    for sibling_hash in proof:
        if position % 2 == 0:
            # Even position: the running hash is the left operand.
            current_hash = hash_data(current_hash + sibling_hash)
        else:
            # Odd position: the sibling is the left operand.
            current_hash = hash_data(sibling_hash + current_hash)
        position //= 2
    return current_hash == root

def generate_merkle_proof(tree: List[str], index: int, total_leaves: int) -> List[str]:
    """
    Generate the Merkle proof for the leaf at `index`.

    `tree` is read as a flat array in which the root is `tree[0]`, the parent
    of node `i` is `(i - 1) // 2`, and leaf `index` sits at position
    `(total_leaves - 1) + index`.

    Args:
        tree: Node hashes in the flat layout described above.
        index: Zero-based position of the leaf among the leaves.
        total_leaves: Number of leaves in the tree.

    Returns:
        The sibling hashes on the path from the leaf up to (but excluding)
        the root, leaf level first. A sibling position that falls past the
        end of `tree` is skipped.

    Raises:
        IndexError: If `index` is not in `range(total_leaves)` or the leaf's
            position lies outside `tree`.
    """
    # Without this guard an out-of-range index silently maps onto some other
    # node (an interior node for negative values, a phantom node past the end
    # for large ones) and a proof for *that* node would be returned.
    if not 0 <= index < total_leaves:
        raise IndexError(
            f"leaf index {index} out of range for a tree with {total_leaves} leaves"
        )

    node_index = (total_leaves - 1) + index
    if node_index >= len(tree):
        raise IndexError(
            f"leaf position {node_index} lies outside a tree of {len(tree)} nodes"
        )

    proof = []
    while node_index > 0:
        # Odd positions have their sibling to the right, even ones to the left.
        sibling_index = node_index + 1 if node_index % 2 else node_index - 1
        if sibling_index < len(tree):
            proof.append(tree[sibling_index])
        node_index = (node_index - 1) // 2
    return proof

def verify_from_saved_tree(tree_path: str, index: int, value: str) -> None:
    """
    Load a saved Merkle tree and report whether `value` is the leaf at `index`.

    The JSON file at `tree_path` is loaded and used as a flat sequence of node
    hashes whose first element is the root. The leaf count is derived from the
    node count as `(len(tree) + 1) // 2`. A proof for the leaf is generated
    from the tree and then verified against the root.

    Args:
        tree_path: Path to the JSON file holding the tree.
        index: Zero-based position of the leaf to check.
        value: Original (unhashed) value claimed for that leaf.

    Prints a single result line; returns nothing. An index outside the leaf
    range (which includes every index when the tree is empty) is reported as
    INVALID.
    """
    with open(tree_path, "r", encoding="utf-8") as f:
        tree = json.load(f)

    total_nodes = len(tree)
    leaf_count = (total_nodes + 1) // 2

    # Only build and check a proof when the index names a real leaf. Passing
    # an out-of-range index on would address a different node of the tree
    # (and an empty tree has no root to compare against at all).
    if 0 <= index < leaf_count:
        proof = generate_merkle_proof(tree, index, leaf_count)
        root = tree[0]
        valid = verify_merkle_proof(index, value, proof, root, leaf_count)
    else:
        valid = False
    print(f"[✓] Merkle proof verification for index {index}: {'✅ VALID' if valid else '❌ INVALID'}")

if __name__ == '__main__':
    # Command-line entry point: all three options are mandatory and are
    # forwarded unchanged to verify_from_saved_tree.
    cli = argparse.ArgumentParser(description="Verify a Merkle proof from a stored tree")
    option_specs = (
        ("--tree_path", str, "Path to the saved Merkle tree JSON file"),
        ("--index", int, "Index of the leaf to verify"),
        ("--value", str, "Original value of the leaf node"),
    )
    for flag, converter, help_text in option_specs:
        cli.add_argument(flag, type=converter, required=True, help=help_text)

    options = cli.parse_args()
    verify_from_saved_tree(options.tree_path, options.index, options.value)
